{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "This example requires the following dependencies to be installed:\n",
    "`pip install lightly`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the package is installed into the running kernel's environment.\n",
    "%pip install lightly"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "Note: The model and training settings do not follow the reference settings\n",
    "from the paper. The settings are chosen such that the example can easily be\n",
    "run on a small dataset with a single GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightly.loss import DCLLoss\n",
    "from lightly.models.modules import SimCLRProjectionHead\n",
    "from lightly.transforms.simclr_transform import SimCLRTransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DCL(nn.Module):\n",
    "    \"\"\"Backbone followed by a SimCLR-style projection head for DCL training.\n",
    "\n",
    "    Args:\n",
    "        backbone: Module whose output, flattened from dimension 1 onwards,\n",
    "            has ``input_dim`` features per sample.\n",
    "        input_dim: Number of features per sample fed to the projection head.\n",
    "        hidden_dim: Hidden dimension passed to the projection head.\n",
    "        output_dim: Size of the embedding returned by ``forward``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, backbone, input_dim=512, hidden_dim=512, output_dim=128):\n",
    "        super().__init__()\n",
    "        self.backbone = backbone\n",
    "        # The defaults match the 512-dimensional features of a ResNet-18 backbone.\n",
    "        self.projection_head = SimCLRProjectionHead(input_dim, hidden_dim, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the projected embedding for a batch of inputs ``x``.\"\"\"\n",
    "        # Collapse everything after the batch dimension before projecting.\n",
    "        features = self.backbone(x).flatten(start_dim=1)\n",
    "        return self.projection_head(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a ResNet-18 without pretrained weights and drop its last child\n",
    "# module so that the remaining layers can serve as the feature backbone.\n",
    "resnet = torchvision.models.resnet18()\n",
    "feature_layers = list(resnet.children())[:-1]\n",
    "backbone = nn.Sequential(*feature_layers)\n",
    "model = DCL(backbone)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the GPU when CUDA is available, otherwise run on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SimCLR augmentations configured for 32x32 inputs.\n",
    "transform = SimCLRTransform(input_size=32)\n",
    "dataset = torchvision.datasets.CIFAR10(\n",
    "    root=\"datasets/cifar10\",\n",
    "    transform=transform,\n",
    "    download=True,\n",
    ")\n",
    "# Alternatively, create a dataset from a folder containing images or videos\n",
    "# (requires `from lightly.data import LightlyDataset`):\n",
    "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reshuffle the data every epoch and drop the final incomplete batch so that\n",
    "# every training step sees exactly `batch_size` samples.\n",
    "loader_kwargs = dict(batch_size=256, shuffle=True, drop_last=True, num_workers=8)\n",
    "dataloader = torch.utils.data.DataLoader(dataset, **loader_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss applied to the two batches of embeddings in the training loop.\n",
    "criterion = DCLLoss()\n",
    "# or use the weighted DCLW loss (requires `from lightly.loss import DCLWLoss`):\n",
    "# criterion = DCLWLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain SGD over every parameter of the model (backbone and projection head).\n",
    "optimizer = torch.optim.SGD(\n",
    "    params=model.parameters(),\n",
    "    lr=0.06,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Starting Training\")\n",
    "num_epochs = 10\n",
    "for epoch in range(num_epochs):\n",
    "    # Accumulate the detached per-batch losses to report an epoch average.\n",
    "    running_loss = 0\n",
    "    for batch in dataloader:\n",
    "        # The first batch element unpacks into the two views of each image.\n",
    "        view_a, view_b = (view.to(device) for view in batch[0])\n",
    "        emb_a, emb_b = model(view_a), model(view_b)\n",
    "        batch_loss = criterion(emb_a, emb_b)\n",
    "        running_loss += batch_loss.detach()\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "        # Clear the gradients so they do not accumulate into the next step.\n",
    "        optimizer.zero_grad()\n",
    "    mean_loss = running_loss / len(dataloader)\n",
    "    print(f\"epoch: {epoch:>02}, loss: {mean_loss:.5f}\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
